{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder -- replace with the location of the checkpoint to load.\n",
    "llama_13B_path = 'Your Path'\n",
    "\n",
    "# Reference model: later cells compare merged candidates against it.\n",
    "llama_13B = AutoModelForCausalLM.from_pretrained(\n",
    "    llama_13B_path,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer loaded from the same checkpoint path as the model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    llama_13B_path,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step by which the base-layer index moves down after a merge is accepted.\n",
    "INTERVAL = 1\n",
    "\n",
    "# Merge window size: MERGE_LAYERS - 1 following layers are folded into the base layer.\n",
    "MERGE_LAYERS = 7\n",
    "\n",
    "# Layer index the search starts from (the first base layer is HIGHEST_LAY - MERGE_LAYERS).\n",
    "HIGHEST_LAY = 39\n",
    "\n",
    "# The search stops once the base-layer index drops below this value.\n",
    "LOWEST_LAY = 0\n",
    "\n",
    "# A merge is kept only if the mean cosine similarity to the reference model exceeds this.\n",
    "THRESHOLD = 0.45"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "def merge_layers_return_model(model, merge_base_lay, merge_layer_num):\n",
    "    \"\"\"Return a deep copy of `model` with a run of decoder layers folded into one.\n",
    "\n",
    "    The `merge_layer_num` layers that follow `merge_base_lay` are merged into\n",
    "    the base layer and then removed from the copy. For every merged layer, the\n",
    "    difference between its projection weights and the base layer's *original*\n",
    "    weights is added onto the base layer:\n",
    "\n",
    "        base_merged = base + sum_k (layer_k - base)\n",
    "\n",
    "    Only the `.weight` tensors of the MLP projections (gate/down/up) and the\n",
    "    attention projections (q/k/v/o) are merged; nothing else in the base layer\n",
    "    is changed. `model` itself is never modified.\n",
    "\n",
    "    Args:\n",
    "        model: model exposing its decoder blocks as `model.model.layers`.\n",
    "        merge_base_lay: index of the layer that absorbs the following layers.\n",
    "        merge_layer_num: number of following layers to merge; clamped so the\n",
    "            run never extends past the last layer.\n",
    "\n",
    "    Returns:\n",
    "        The merged deep copy, shorter than `model` by the number of layers\n",
    "        actually merged.\n",
    "    \"\"\"\n",
    "    merge_layer_num = min(merge_layer_num, len(model.model.layers) - merge_base_lay - 1)\n",
    "\n",
    "    model_copy = deepcopy(model)\n",
    "    if merge_layer_num <= 0:\n",
    "        # Nothing to fold in: hand back the untouched copy.\n",
    "        return model_copy\n",
    "\n",
    "    # (sub-module, projection) pairs whose weights take part in the merge.\n",
    "    merged_projections = (\n",
    "        ('mlp', 'gate_proj'),\n",
    "        ('mlp', 'down_proj'),\n",
    "        ('mlp', 'up_proj'),\n",
    "        ('self_attn', 'q_proj'),\n",
    "        ('self_attn', 'k_proj'),\n",
    "        ('self_attn', 'v_proj'),\n",
    "        ('self_attn', 'o_proj'),\n",
    "    )\n",
    "\n",
    "    def _weight(layer, block_name, proj_name):\n",
    "        # Raw weight tensor of one projection inside a decoder layer.\n",
    "        return getattr(getattr(layer, block_name), proj_name).weight.data\n",
    "\n",
    "    # Differences are measured against the ORIGINAL base weights (read from\n",
    "    # `model`), not against the running sum in `model_copy`. Subtracting the\n",
    "    # running sum would cancel every earlier contribution and leave the base\n",
    "    # layer equal to the last merged layer.\n",
    "    base_layer = model.model.layers[merge_base_lay]\n",
    "    target_layer = model_copy.model.layers[merge_base_lay]\n",
    "    for diff_lay in range(merge_base_lay + 1, merge_base_lay + 1 + merge_layer_num):\n",
    "        diff_layer = model.model.layers[diff_lay]\n",
    "        for block_name, proj_name in merged_projections:\n",
    "            _weight(target_layer, block_name, proj_name).add_(\n",
    "                _weight(diff_layer, block_name, proj_name)\n",
    "                - _weight(base_layer, block_name, proj_name)\n",
    "            )\n",
    "\n",
    "    # Drop the absorbed layers, highest index first so the remaining indices stay valid.\n",
    "    for diff_lay in range(merge_base_lay + merge_layer_num, merge_base_lay, -1):\n",
    "        del model_copy.model.layers[diff_lay]\n",
    "    return model_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Working model for the layer-merge search; `llama_13B` stays as the reference.\n",
    "llama_13B_copy_to_compress = copy.deepcopy(\n",
    "    llama_13B\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def cal_last_hidden_sim(model1, model2, tokenizer, sents):\n",
    "    \"\"\"Mean cosine similarity between two models' final hidden states.\n",
    "\n",
    "    Each sentence is tokenized once and fed to both models. The last entry of\n",
    "    each model's `hidden_states` is flattened into a single vector and the two\n",
    "    vectors are compared with cosine similarity. The per-sentence values and\n",
    "    their mean are printed, and the mean is returned.\n",
    "\n",
    "    Args:\n",
    "        model1: first model, called with the tokenizer output and\n",
    "            `output_hidden_states=True`.\n",
    "        model2: second model, called the same way on the same inputs.\n",
    "        tokenizer: callable invoked as `tokenizer(sentence, return_tensors='pt')`.\n",
    "        sents: iterable of input sentences.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-sentence cosine similarities (via `np.mean`).\n",
    "    \"\"\"\n",
    "    per_sentence = []\n",
    "    for text in sents:\n",
    "        batch = tokenizer(text, return_tensors='pt')\n",
    "        with torch.no_grad():\n",
    "            # Final hidden state of each model; expected shape (1, seq_len, hidden).\n",
    "            last_hidden1 = model1(**batch, output_hidden_states=True).hidden_states[-1]\n",
    "            last_hidden2 = model2(**batch, output_hidden_states=True).hidden_states[-1]\n",
    "        # Collapse every position into one row vector per model before comparing.\n",
    "        vec1 = last_hidden1.squeeze(0).flatten().unsqueeze(0)\n",
    "        vec2 = last_hidden2.squeeze(0).flatten().unsqueeze(0)\n",
    "        per_sentence.append(torch.cosine_similarity(vec1, vec2).item())\n",
    "    mean_sim = np.mean(per_sentence)\n",
    "    print(per_sentence, mean_sim)\n",
    "    return mean_sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First base layer of the search: MERGE_LAYERS below HIGHEST_LAY.\n",
    "lay = HIGHEST_LAY - MERGE_LAYERS\n",
    "last_merge_flag = False\n",
    "\n",
    "# Truncated English text snippets used as probe inputs.\n",
    "en_wiki_selected = [\n",
    "    'Mouron () is a commune in the Arde',\n",
    "    'The 81st Mechanised Brigade () is a mechanised brigade of the Romanian Land Force',\n",
    "    'There are 18 National Natural Landmarks in the U.S. state of Washington, out of nearly',\n",
    "    'Torreorgaz is a municipality in the',\n",
    "    'Copa Libertadores 1973 was won by defending champions Independiente of A',\n",
    "]\n",
    "# Probe list handed to the similarity check; filled further down in this cell.\n",
    "sents = []\n",
    "\n",
    "# zh_wiki_selected = ['月桃   \\xa0\\xa0月桃月桃属草本，单叶，互生，具',\n",
    "#  '法国立贝尔洁白牙贴  目录产品成份：产品功效：用法用量：注意事项：产品禁忌：不良反应：规\\xa0 \\xa0 格：医疗器械注册号：产品执行标准：生产许可证号：授权监制：生产企业：',\n",
    "#  'TIMKEN 641/632-B轴承  目录TIMK',\n",
    "#  '天然碳化物质微结构研究  目录图书信息内容简介  图书信息作\\u3000\\u3000者： 冯有利 著 \\n丛 书 名：\\xa0\\xa0出 版 社： 地质出版社 ISBN：9787116059771 出版时间',\n",
    "#  'V字领衣服  目录基本信息']\n",
    "\n",
    "# Probe set: the English snippets only (the Chinese set stays disabled).\n",
    "sents += en_wiki_selected\n",
    "# sents.extend(zh_wiki_selected)\n",
    "\n",
    "\n",
    "# Walk the base-layer index downwards, trying one merge per step.\n",
    "while lay >= LOWEST_LAY:\n",
    "    print(lay)\n",
    "    print('current model layer', len(llama_13B_copy_to_compress.model.layers))\n",
    "\n",
    "    # Candidate: fold the MERGE_LAYERS - 1 layers after `lay` into layer `lay`.\n",
    "    tmp_merged_model = merge_layers_return_model(\n",
    "        llama_13B_copy_to_compress, lay, MERGE_LAYERS - 1\n",
    "    )\n",
    "    sim_value = cal_last_hidden_sim(llama_13B, tmp_merged_model, tokenizer, sents)\n",
    "\n",
    "    if not sim_value > THRESHOLD:\n",
    "        # Rejected: keep the current model and try the next lower base layer.\n",
    "        lay -= 1\n",
    "        continue\n",
    "\n",
    "    # Accepted: adopt the candidate and move the base index down by INTERVAL.\n",
    "    llama_13B_copy_to_compress = tmp_merged_model\n",
    "    lay -= INTERVAL\n",
    "    if lay >= len(llama_13B_copy_to_compress.model.layers):\n",
    "        # Base index fell outside the shortened model: restart below its top.\n",
    "        lay = len(llama_13B_copy_to_compress.model.layers) - 1 - MERGE_LAYERS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the config's layer count in step with the current length of the layer list.\n",
    "llama_13B_copy_to_compress.config.num_hidden_layers = len(\n",
    "    llama_13B_copy_to_compress.model.layers\n",
    ")\n",
    "\n",
    "# Bare expression: the notebook shows the model's repr as the cell output.\n",
    "llama_13B_copy_to_compress"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
